"""
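Compute inventor centrality measures from the co-invention network: two inventors
are linked when they appear on the same patent, and each edge is weighted by the
number of patents they share. Only patents granted before --year are used.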
Date: 2023-03-20
Author: Shaoyu Liu

example code:
    cd /Volumes/Zihao_SSD2/PatentsView    
    python3 code/inventor_centrality.py \
    --path_inventor rawdata/g_inventor_disambiguated.tsv \
    --path_patent rawdata/g_patent.tsv \
    --year 1981 \
    --output_dir centrality_results
"""

#coauthorship network
import pandas as pd
import numpy as np
import networkx as nx
import os
import logging
import argparse


def get_inventor(path_inventor, path_patent, year):
    """
    Build the co-invention graph from all patents granted before ``year``.

    Two inventors are connected when they are listed on the same patent; the
    edge attribute ``weight`` is the number of patents the pair shares.
    Inventors with no co-inventor in the period are added as isolated nodes,
    so the graph contains every inventor seen on a patent before ``year``.

    Args:
        path_inventor: path to g_inventor_disambiguated.tsv (tab-separated;
            columns patent_id, inventor_id, location_id are read)
        path_patent: path to g_patent.tsv (tab-separated; columns patent_id
            and patent_date are read, and the first four characters of
            patent_date are taken as the year)
        year: only patents whose year is strictly less than this are used
    Returns:
        (G, all_inventors): the undirected networkx Graph, and the set of all
        inventor_ids that appear on a patent before ``year``.
    """
    # Fetched locally so the function also works when this module is imported
    # rather than run as a script (the global `logger` only exists then).
    logger = logging.getLogger(__name__)

    df = pd.read_csv(path_inventor, delimiter="\t",
                     usecols=['patent_id', 'inventor_id', 'location_id'], low_memory=False)
    df_patent = pd.read_csv(path_patent, delimiter="\t",
                            usecols=['patent_id', 'patent_date'], low_memory=False)
    df_patent.rename(columns={'patent_date': 'year'}, inplace=True)
    df_patent['year'] = df_patent['year'].str[:4].astype("int16")

    df = df.merge(df_patent, on=['patent_id'], how="inner")
    min_yr = df['year'].min()
    logger.info(f'min year:{min_yr}')
    df = df[df['year'] < year]

    # Self-join on patent_id: one row per ordered pair of inventors listed on
    # the same patent, so both (A, B) and (B, A) appear.
    pairs = df[['patent_id', 'inventor_id']].merge(df[['patent_id', 'inventor_id']], on=['patent_id'])
    logger.info(len(pairs))

    logger.info("drop self-loops")
    pairs = pairs[pairs['inventor_id_x'] != pairs['inventor_id_y']]
    logger.info(len(pairs))

    # weight = number of patents on which the pair co-occurs
    edges = pairs.groupby(['inventor_id_x', 'inventor_id_y']).size().reset_index()
    edges.columns = ['inventor_id_x', 'inventor_id_y', 'weight']

    # Undirected graph: the (A, B) and (B, A) rows collapse onto one edge that
    # carries the same weight.
    G = nx.from_pandas_edgelist(edges, source='inventor_id_x', target='inventor_id_y', edge_attr=['weight'])
    logger.info(f'number of edges: {G.number_of_edges()}')
    logger.info(f'number of nodes: {G.number_of_nodes()}')

    all_inventors = set(df['inventor_id'])

    # Inventors without any co-inventor never appear in the edge list. They
    # must be added one by one (add_nodes_from); add_node would insert the
    # whole collection as a single tuple-valued node.
    isolated = all_inventors.difference(edges['inventor_id_x'])
    logger.info(f'number of isolated inventors: {len(isolated)}')
    G.add_nodes_from(isolated)

    return G, all_inventors


def compute_centrality(graph, year, output_dir, method, all_inventors) -> None:
    """
    Compute one centrality measure for every inventor and write it to CSV.

    Args:
        graph: networkx graph of co-inventors (e.g. as built by get_inventor)
        year: year label used in the output file name
        output_dir: directory the CSV is written to (created if missing)
        method: centrality to compute, also used in the output file name:
            "degree"      -> nx.degree_centrality,      O(V^2)
            "closeness"   -> nx.closeness_centrality,   O(VE + V^2)
            "betweenness" -> nx.betweenness_centrality, O(VE + V^2)
        all_inventors: inventor ids to report; ids that are not nodes of
            ``graph`` get a centrality of 0
    Returns:
        None. Writes ``centrality_{year}_{method}.csv`` with columns
        ``inventor_id`` and ``deg_centrality`` / ``cls_centrality`` /
        ``bet_centrality`` depending on ``method``.
    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    logger = logging.getLogger(__name__)

    # Dispatch on `method` so the computed measure always matches the file
    # name; resolve it before any expensive work so a typo fails fast.
    measures = {
        'degree': ('deg', nx.degree_centrality),
        'closeness': ('cls', nx.closeness_centrality),
        'betweenness': ('bet', nx.betweenness_centrality),
    }
    try:
        prefix, measure = measures[method]
    except KeyError:
        raise ValueError(
            f"unsupported centrality method {method!r}; expected one of {sorted(measures)}"
        ) from None

    logger.info("start computing centrality")
    scores = measure(graph)
    logger.info("done computing centrality")

    # One row per inventor; ids missing from the graph's result map to NaN,
    # which is reported as 0 centrality.
    c1 = pd.DataFrame({'inventor_id': list(all_inventors)})
    c1[f'{prefix}_centrality'] = c1['inventor_id'].map(scores).fillna(0)

    # to_csv does not create directories; an empty output_dir means the cwd.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"centrality_{year}_{method}.csv")
    c1.to_csv(output_path, index=False)
    logger.info(f'{method} centrality measure for {year} saved!')
    

def main(path_inventor, path_patent, year, output_dir, method="closeness"):
    """
    Build the co-invention graph for ``year`` and save its centrality CSV.

    Args:
        path_inventor: path to g_inventor_disambiguated.tsv
        path_patent: path to g_patent.tsv
        year: cutoff year; only patents granted before it are used
        output_dir: directory for the output CSV
        method: centrality method name, forwarded to compute_centrality
    """
    print(f'Year = {year}')
    network, inventor_ids = get_inventor(path_inventor, path_patent, year)
    compute_centrality(
        graph=network,
        year=year,
        output_dir=output_dir,
        method=method,
        all_inventors=inventor_ids,
    )


if __name__ == '__main__':
    ap = argparse.ArgumentParser(
        description='Compute inventor centrality from the co-invention network.')
    ap.add_argument('--path_inventor', required=True, help='specify path to inventor')
    ap.add_argument('--path_patent', required=True, help='specify path to patent')
    ap.add_argument('--year', type=int, required=True,
                    help='only patents granted before this year are used')
    ap.add_argument('--output_dir', required=True, help='directory for the output CSV')
    ap.add_argument('--workdir', default=r'/Volumes/Zihao_SSD2/PatentsView/',
                    help='directory to chdir into; relative paths resolve against it')
    args = ap.parse_args()

    # Change directory only after parsing, so `--help` and argument errors
    # work on machines where the default workdir does not exist.
    os.chdir(args.workdir)

    # logging.basicConfig does not create missing directories.
    os.makedirs('logs', exist_ok=True)
    logging.basicConfig(
        filename=os.path.join('logs', f'centrality_{args.year}.log'),
        format='%(asctime)s:%(levelname)s:%(message)s',
        level=logging.INFO
    )
    # Bound at module scope so functions that reference the global `logger`
    # can use it.
    logger = logging.getLogger(__name__)

    main(path_inventor=args.path_inventor,
         path_patent=args.path_patent,
         year=args.year,
         output_dir=args.output_dir,
         method="degree")